旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题

您所在的位置:网站首页 pytorch rfft有问题 旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题

旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题

2024-07-17 08:47| 来源: 网络整理| 查看: 265

旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题 前言一、旧版 pytorch.rfft()函数解释二、新版pytorch.fft.rfft()函数解释三、总结

前言

这两天整理谱池化操作,需要用到傅里叶变换这个函数。后来提升了pytorch的版本以后,发现之前的torch.rfft() 函数在新版的pytorch中使用会报错,后来查阅资料,发现是新版的参数有些变动。 pytorch旧版本(1.7之前)中有一个函数torch.rfft(),但是新版本(1.8、1.9)中被移除了,添加了torch.fft.rfft(),但它跟旧版的函数有了很大的变动,参数进行了一个大的调整。 傅里叶变换的整个过程我并没有搞的十分清晰,尤其是pytorch中的引用,网上对于这个函数解析的资料也十分有限,然后从知乎上参考了一篇文章,将我的问题解决了,感谢这位仁兄。

一、旧版 pytorch.rfft()函数解释 fft = torch.rfft(input, 2, normalized=True, onesided=False) # input 为输入的图片或者向量,dtype=torch.float32,size比如为[1,3,64,64]

参数说明:

input (Tensor) – the input tensor of at least signal_ndim dimensions signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3 normalized (bool, optional) – controls whether to return normalized results. Default: False onesided (bool, optional) – controls whether to return half of results to avoid redundancy. Default: True 在上述的代码中,signal_ndim=2 因为图像是二维的,normalized=False 说明不进行归一化,onesided=False 则是希望不要减少最后一个维度的大小

在1.7版本torch.rfft中,有一个warning,表示在新版中,要“one-side ouput”的话用torch.fft.rfft(),要“two-side ouput”的话用torch.fft.fft()。这里的one/two side,跟旧版的onesided参数对应,所以我们要的是新版的torch.fft.fft()

需要注意的是,假设输入tensor的维度为 [ N 1 , N 2 , , , , N d ] [N_1,N_2,,,,N_d] [N1?,N2?,,,,Nd?],则输出tensor的维度为 [ N 1 , N 2 , , , , N d , 2 ] [N_1,N_2,,,,N_d,2] [N1?,N2?,,,,Nd?,2] 。最后一个维度2表示复数中的实部、虚部,即 z = a + b i z =a+bi z=a+bi这样的复数,在旧版pytorch中表示为一个二维向量 [ a , b ] [a,b] [a,b] 。

二、新版pytorch.fft.rfft()函数解释

新版官网解释

Getting started with the new torch.fft module is easy whether you are familiar with NumPy’s np.fft module or not. While complete documentation for each function in the module can be found here, a breakdown of what it offers is:

fft, which computes a complex FFT over a single dimension, and ifft, its inversethe more general fftn and ifftn, which support multiple dimensionsThe “real” FFT functions, rfft, irfft, rfftn, irfftn, designed to work with signals that are real-valued in their time domainsThe “Hermitian” FFT functions, hfft and ihfft, designed to work with signals that are real-valued in their frequency domainsHelper functions, like fftfreq, rfftfreq, fftshift, ifftshift, that make it easier to manipulate signals

官网解释链接:https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/

小结:可以看到这里也有rfft,官方文档说是用来处理都是实数的输入。但是它在前面的warning中说了是one-side,而我们要的是two-side。此外实数也可以看作是虚部都为0的复数,所以用fft没问题。 新版的rfft和fft都是用于一维输入,而我们的图像是二维,所以应该用rfft2和fft2。在fft2中,参数dim用来指定用于傅里叶变换的维度,默认(-2,-1),正好对应H、W两个维度。 新版所有的fft都不将复数 z = a + b j z=a+bj z=a+bj 存成二维向量了,而是一个数 [ z = a + b j ] [z=a+bj] [z=a+bj]。所以如果要跟旧版中一样存成二维向量,需要用.real()和.imag()提取复数的实部和虚部,然后用torch.stack()堆到一起,即可。

三、总结

代码变更对比如下:

import torch input = torch.rand(1,3,32,32) # 旧版pytorch.rfft()函数 fft = torch.rfft(input, 2, normalized=True, onesided=False) # 新版 pytorch.fft.rfft2()函数 output = torch.fft.fft2(input, dim=(-2, -1)) output = torch.stack((output.real, output_new.imag), -1)

以上是我的理解,整体理解参考文章如下连接。

知乎:旧版pytorch中torch.rfft在新版本中的对应



【本文地址】


今日新闻


推荐新闻


    CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3